Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlas int4 int8 with avx2/512 #20687

Merged
merged 48 commits into from
Aug 2, 2024
Merged

Mlas int4 int8 with avx2/512 #20687

merged 48 commits into from
Aug 2, 2024

Conversation

liqunfu
Copy link
Contributor

@liqunfu liqunfu commented May 15, 2024

Description

model: phi-3-mini-4k-instruct
avx2 symmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 49.5 70.0 -29.2% 9.6 10.8 -34.2%
32 76.8 52.4 9.7% 15.2 14.6 4.1%
64 78.2 71.4 9.5% 16.6 16.3 1.8%
128 72.9 70.6 3.2% 17.1 16.8 1.7%
256 83.7 63.6 31.6% 18.1 17.4 4%

avx2 asymmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 50.7 61.5 -17.5% 9.6 9.2 4.3%
32 77.4 52.4 47.7% 14.6 13.9 5.0%
64 78.7 63.0 24.9% 16.2 15.9 1.8%
128 80.0 61.9 29.2% 17.2 16.9 1.7%
256 81.5 63.3 28.7% 17.9 17.3 3.4%

avx2vnni symmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 82.9 117.0 -29.0% 15.9 19.3 -17.6%
32 133.0 100.4 32.4% 26.1 24.5 6.5%
64 166.9 118.8 40.4% 28.3 27.1 4.4%
128 165.9 119.6 38.7% 29.3 28.5 2.8%
256 165.2 119.6 38.1% 30.2 29.0 4.1%

avx2vnni asymmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 80.2 118.9 -32.5% 15.1 16.7 -9.5%
32 130.7 99.7 31.0% 25.0 23.8 5.0%
64 168.7 124.9 35.0% 27.3 26.8 1.8%
128 169.6 123.8 36.9% 29.2 27.9 4.6%
256 175.0 125.7 39.0% 30.0 29.7 1.0%

avx512 symmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 135.2 156.5 -13.6% 25.5 23.8 7.1%
32 150.0 159.5 -5.9% 34.9 29.6 17.9%
64 167.5 157.5 6.3% 39.7 34.4 15.4%
128 177.8 158.0 12.5% 40.3 35.4 13.8%
256 182.6 157.3 16.0% 41.7 37.7 10.6%

avx512 asymmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 136.1 151.4 -10.1% 26.1 19.9 31.1%
32 150.0 157.8 -4.9% 34.3 29.3 17.0%
64 165.7 156.6 5.8% 38.7 30.7 26.0%
128 180.4 156.6 15.1% 40.2 34.7 15.8%
256 181.3 158.0 14.7% 41.6 36.6 13.6%

avx512vnni symmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 143.4 155.4 -7.7% 25.6 23.3 9.8%
32 159.2 157.0 1.4% 34.1 29.8 14.4%
64 182.0 159.5 14.1% 38.4 34.8 10.3%
128 221.2 160.8 37.5% 41.0 36.4 12.6%
256 250.5 162.4 54.2% 41.6 37.7 10.3%

avx512vnni asymmetric

blklen updated prompt tps baseline prompt tps prompt tps change% updated token gen tps baseline token gen tps token gen change%
16 142.5 152.3 -6.4% 26.3 19.7 33.5%
32 158.2 155.0 2.0% 34.3 29.2 17.4%
64 184.1 156.6 17.5% 38.3 30.9 23.9%
128 215.8 156.1 17.5% 41.3 35.0 17.9%
256 249.2 155.9 59.8% 41.1 36.3 13.2%

4bit gemm implementation with avx using tile.

tile size is 2blk by 4. in case of size less then tile, it reduce to 1blk by 4, 2blk by 1 and lastly 1blk by 1.
with internal kernel, weight and activation are loaded based on SIMD register width and blk length:
avx2 256bit register, 64 weights and activation are loaded.
blklen16: 4 blks are computed by the internal kernel
blklen32: 2 blks are computed by the internal kernel
blklen64: 1 blk are computed by the internal kernel
blklen128: 1 blks are computed 2 times by the internal kernel
blklen16: 1 blks are computed 4 times by the internal kernel

avx512 512bit register, 128 weights and activation are loaded.
blklen16: 8 blks are computed by the internal kernel
blklen32: 4 blks are computed by the internal kernel
blklen64: 2 blk are computed by the internal kernel
blklen128: 1 blks are computed by the internal kernel
blklen16: 1 blks are computed 2 times by the internal kernel

blksum is precomputed during prepacking.
computation is reformed:
Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
Sum_blk is over one blk
Sum1 is over all blks for one output
Sum2 is over all blks for one output
Sum is computed with sgemm with the current implementation. Further improvement is possible.

 

liqunfu added 8 commits May 2, 2024 20:00
…en32, symmetric1 hasBias0 Int8

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
…tric:1/ComputeType:4/real_time_mean 1542487160 ns 1539062500 ns

Signed-off-by: Liqun Fu <[email protected]>
…048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1434872720 ns

Signed-off-by: Liqun Fu <[email protected]>
…NBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1265060620 ns 1265625000 ns

Signed-off-by: Liqun Fu <[email protected]>
…TGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1214042220 ns

Signed-off-by: Liqun Fu <[email protected]>
…6/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 784668090 ns; SQNBITGEMM<4>/BlkLen:64/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 754939430 ns

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
@liqunfu liqunfu requested a review from a team as a code owner May 15, 2024 17:01
@liqunfu liqunfu marked this pull request as draft May 15, 2024 17:03
…ymmetric:1/ComputeType:4/real_time_mean 664029830 ns

Signed-off-by: liqunfu <[email protected]>
@liqunfu liqunfu changed the title Mlas int4 int8 with avx2 Mlas int4 int8 with avx2/512 May 26, 2024
Copy link

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PREfast found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

Signed-off-by: liqunfu <[email protected]>
liqunfu added 3 commits July 30, 2024 20:43
Signed-off-by: liqunfu <[email protected]>
@liqunfu liqunfu marked this pull request as ready for review July 30, 2024 21:37
liqunfu added 2 commits July 30, 2024 22:52
Signed-off-by: liqunfu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
onnxruntime/core/mlas/inc/mlas_qnbit.h Outdated Show resolved Hide resolved
onnxruntime/core/mlas/inc/mlas_qnbit.h Show resolved Hide resolved
onnxruntime/core/mlas/inc/mlas_qnbit.h Outdated Show resolved Hide resolved
onnxruntime/core/mlas/lib/sqnbitgemm.cpp Outdated Show resolved Hide resolved
onnxruntime/core/mlas/lib/sqnbitgemm.h Outdated Show resolved Hide resolved
onnxruntime/core/mlas/lib/sqnbitgemm.h Show resolved Hide resolved
onnxruntime/core/mlas/lib/sqnbitgemm.cpp Show resolved Hide resolved
Copy link
Member

@yufenglee yufenglee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@liqunfu liqunfu merged commit b87e8ed into main Aug 2, 2024
98 checks passed
@liqunfu liqunfu deleted the liqun/mlas-q4-tile-avx branch August 2, 2024 17:20
@prathikr prathikr added the release:1.19.0 Cherry pick to ORT 1.19 label Aug 2, 2024
prathikr pushed a commit that referenced this pull request Aug 3, 2024
### Description
model: phi-3-mini-4k-instruct
avx2 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |49.5|70.0|-29.2%|9.6|10.8|-34.2%
32 |76.8|52.4|9.7%|15.2|14.6|4.1%
64 |78.2|71.4|9.5%|16.6|16.3|1.8%
128 |72.9|70.6|3.2%|17.1|16.8|1.7%
256 |83.7|63.6|31.6%|18.1|17.4|4%

avx2 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |50.7|61.5|-17.5%|9.6|9.2|4.3%
32 |77.4|52.4|47.7%|14.6|13.9|5.0%
64 |78.7|63.0|24.9%|16.2|15.9|1.8%
128 |80.0|61.9|29.2%|17.2|16.9|1.7%
256 |81.5|63.3|28.7%|17.9|17.3|3.4%

avx2vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |82.9|117.0|-29.0%|15.9|19.3|-17.6%
32 |133.0|100.4|32.4%|26.1|24.5|6.5%
64 |166.9|118.8|40.4%|28.3|27.1|4.4%
128 |165.9|119.6|38.7%|29.3|28.5|2.8%
256 |165.2|119.6|38.1%|30.2|29.0|4.1%

avx2vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |80.2|118.9|-32.5%|15.1|16.7|-9.5%
32 |130.7|99.7|31.0%|25.0|23.8|5.0%
64 |168.7|124.9|35.0%|27.3|26.8|1.8%
128 |169.6|123.8|36.9%|29.2|27.9|4.6%
256 |175.0|125.7|39.0%|30.0|29.7|1.0%

avx512 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |135.2|156.5|-13.6|25.5|23.8|7.1
32 |150.0|159.5|-5.9|34.9|29.6|17.9
64 |167.5|157.5|6.3|39.7|34.4|15.4
128 |177.8|158.0|12.5|40.3|35.4|13.8
256 |182.6|157.3|16.0|41.7|37.7|10.6

avx512 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |136.1|151.4|-10.1%|26.1|19.9|31.1%
32 |150.0|157.8|-4.9%|34.3|29.3|17.0%
64 |165.7|156.6|5.8%|38.7|30.7|26.0%
128 |180.4|156.6|15.1%|40.2|34.7|15.8%
256 |181.3|158.0|14.7%|41.6|36.6|13.6%

avx512vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |143.4|155.4|-7.7%|25.6|23.3|9.8%
32 |159.2|157.0|1.4%|34.1|29.8|14.4%
64 |182.0|159.5|14.1%|38.4|34.8|10.3%
128 |221.2|160.8|37.5%|41.0|36.4|12.6%
256 |250.5|162.4|54.2%|41.6|37.7|10.3%

avx512vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |142.5|152.3|-6.4%|26.3|19.7|33.5%
32 |158.2|155.0|2.0%|34.3|29.2|17.4%
64 |184.1|156.6|17.5%|38.3|30.9|23.9%
128 |215.8|156.1|17.5%|41.3|35.0|17.9%
256 |249.2|155.9|59.8%|41.1|36.3|13.2%


4bit gemm implementation with avx using tile.

1.
tile size is 2blk by 4. in case of size less then tile, it reduce to
1blk by 4, 2blk by 1 and lastly 1blk by 1.
with internal kernel, weight and activation are loaded based on SIMD
register width and blk length:
avx2 256bit register, 64 weights and activation are loaded.
   blklen16: 4 blks are computed by the internal kernel
   blklen32: 2 blks are computed by the internal kernel
   blklen64: 1 blk are computed by the internal kernel
   blklen128: 1 blks are computed 2 times by the internal kernel
   blklen16: 1 blks are computed 4 times by the internal kernel

avx512 512bit register, 128 weights and activation are loaded.
   blklen16: 8 blks are computed by the internal kernel
   blklen32: 4 blks are computed by the internal kernel
   blklen64: 2 blk are computed by the internal kernel
   blklen128: 1 blks are computed by the internal kernel
   blklen16: 1 blks are computed 2 times by the internal kernel

2.
blksum is precomputed during prepacking. 
computation is reformed:
Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
  Sum_blk is over one blk
  Sum1 is over all blks for one output
  Sum2 is over all blks for one output
Sum is computed with sgemm with the current implementation. Further
improvement is possible.

 

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
prathikr pushed a commit that referenced this pull request Aug 5, 2024
### Description
model: phi-3-mini-4k-instruct
avx2 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |49.5|70.0|-29.2%|9.6|10.8|-34.2%
32 |76.8|52.4|9.7%|15.2|14.6|4.1%
64 |78.2|71.4|9.5%|16.6|16.3|1.8%
128 |72.9|70.6|3.2%|17.1|16.8|1.7%
256 |83.7|63.6|31.6%|18.1|17.4|4%

avx2 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |50.7|61.5|-17.5%|9.6|9.2|4.3%
32 |77.4|52.4|47.7%|14.6|13.9|5.0%
64 |78.7|63.0|24.9%|16.2|15.9|1.8%
128 |80.0|61.9|29.2%|17.2|16.9|1.7%
256 |81.5|63.3|28.7%|17.9|17.3|3.4%

avx2vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |82.9|117.0|-29.0%|15.9|19.3|-17.6%
32 |133.0|100.4|32.4%|26.1|24.5|6.5%
64 |166.9|118.8|40.4%|28.3|27.1|4.4%
128 |165.9|119.6|38.7%|29.3|28.5|2.8%
256 |165.2|119.6|38.1%|30.2|29.0|4.1%

avx2vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |80.2|118.9|-32.5%|15.1|16.7|-9.5%
32 |130.7|99.7|31.0%|25.0|23.8|5.0%
64 |168.7|124.9|35.0%|27.3|26.8|1.8%
128 |169.6|123.8|36.9%|29.2|27.9|4.6%
256 |175.0|125.7|39.0%|30.0|29.7|1.0%

avx512 symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |135.2|156.5|-13.6|25.5|23.8|7.1
32 |150.0|159.5|-5.9|34.9|29.6|17.9
64 |167.5|157.5|6.3|39.7|34.4|15.4
128 |177.8|158.0|12.5|40.3|35.4|13.8
256 |182.6|157.3|16.0|41.7|37.7|10.6

avx512 asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |136.1|151.4|-10.1%|26.1|19.9|31.1%
32 |150.0|157.8|-4.9%|34.3|29.3|17.0%
64 |165.7|156.6|5.8%|38.7|30.7|26.0%
128 |180.4|156.6|15.1%|40.2|34.7|15.8%
256 |181.3|158.0|14.7%|41.6|36.6|13.6%

avx512vnni symmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |143.4|155.4|-7.7%|25.6|23.3|9.8%
32 |159.2|157.0|1.4%|34.1|29.8|14.4%
64 |182.0|159.5|14.1%|38.4|34.8|10.3%
128 |221.2|160.8|37.5%|41.0|36.4|12.6%
256 |250.5|162.4|54.2%|41.6|37.7|10.3%

avx512vnni asymmetric
blklen|updated prompt tps | baseline prompt tps | prompt tps
change%|updated token gen tps | baseline token gen tps | token gen
change%
-|-|-|-|-|-|-
16 |142.5|152.3|-6.4%|26.3|19.7|33.5%
32 |158.2|155.0|2.0%|34.3|29.2|17.4%
64 |184.1|156.6|17.5%|38.3|30.9|23.9%
128 |215.8|156.1|17.5%|41.3|35.0|17.9%
256 |249.2|155.9|59.8%|41.1|36.3|13.2%


4bit gemm implementation with avx using tile.

1.
tile size is 2blk by 4. in case of size less then tile, it reduce to
1blk by 4, 2blk by 1 and lastly 1blk by 1.
with internal kernel, weight and activation are loaded based on SIMD
register width and blk length:
avx2 256bit register, 64 weights and activation are loaded.
   blklen16: 4 blks are computed by the internal kernel
   blklen32: 2 blks are computed by the internal kernel
   blklen64: 1 blk are computed by the internal kernel
   blklen128: 1 blks are computed 2 times by the internal kernel
   blklen16: 1 blks are computed 4 times by the internal kernel

avx512 512bit register, 128 weights and activation are loaded.
   blklen16: 8 blks are computed by the internal kernel
   blklen32: 4 blks are computed by the internal kernel
   blklen64: 2 blk are computed by the internal kernel
   blklen128: 1 blks are computed by the internal kernel
   blklen16: 1 blks are computed 2 times by the internal kernel

2.
blksum is precomputed during prepacking. 
computation is reformed:
Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
  Sum_blk is over one blk
  Sum1 is over all blks for one output
  Sum2 is over all blks for one output
Sum is computed with sgemm with the current implementation. Further
improvement is possible.

 

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
@prathikr prathikr added the cherry-picked Cherry-picked for a cherrypicks branch label Aug 6, 2024
}
} // k_blks_remaining

if constexpr (NCols4 == 8) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this "if constexpr" be removed? Since NCols4 can only be 4 at here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cherry-picked Cherry-picked for a cherrypicks branch release:1.19.0 Cherry pick to ORT 1.19
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants